#ifndef _POLYNOMIAL_TRAJ_H
#define _POLYNOMIAL_TRAJ_H

#include <Eigen/Eigen>
#include <vector>

using std::vector;

class PolynomialTraj
{
private:
  vector<double> times;       // time of each segment
  vector<vector<double>> cxs; // coefficient of x of each segment, from n-1 -> 0
  vector<vector<double>> cys; // coefficient of y of each segment
  vector<vector<double>> czs; // coefficient of z of each segment

  double time_sum;
  int num_seg;

  /* evaluation */
  vector<Eigen::Vector3d> traj_vec3d;
  double length;

public:
  PolynomialTraj(/* args */)
  {
  }
  ~PolynomialTraj()
  {
  }

  void reset()
  {
    times.clear(), cxs.clear(), cys.clear(), czs.clear();
    time_sum = 0.0, num_seg = 0;
  }

  void addSegment(vector<double> cx, vector<double> cy, vector<double> cz, double t)
  {
    cxs.push_back(cx), cys.push_back(cy), czs.push_back(cz), times.push_back(t);
  }

  void init();

  vector<vector<double>> getCoef(int axis)
  {
    switch (axis)
    {
    case 0:
      return cxs;
    case 1:
      return cys;
    case 2:
      return czs;
    default:
      std::cout << "\033[31mIllegal axis!\033[0m" << std::endl;
    }

    vector<vector<double>> empty;
    return empty;
  }

  /* for evaluating traj, should be called in sequence!!! */
  double getTimeSum() { return this->time_sum; }

  /* get pos at time t */
  Eigen::Vector3d getPos(double t);

  /* get velocity at time t by derivative of polynomial */
  Eigen::Vector3d getVel(double t);

  /* get acceleration at time t by derivative of polynomial */
  Eigen::Vector3d getAcc(double t);

  vector<Eigen::Vector3d> getTraj()
  {
    double eval_t = 0.0;
    traj_vec3d.clear();
    while (eval_t < time_sum)
    {
      Eigen::Vector3d pt = getPos(eval_t);
      traj_vec3d.push_back(pt);
      eval_t += 0.01;
    }
    return traj_vec3d;
  }

  double getLength()
  {
    length = 0.0;

    Eigen::Vector3d p_l = traj_vec3d[0], p_n;
    for (int i = 1; i < traj_vec3d.size(); ++i)
    {
      p_n = traj_vec3d[i];
      length += (p_n - p_l).norm();
      p_l = p_n;
    }
    return length;
  }

  double getMeanVel()
  {
    double mean_vel = length / time_sum;
  }

  double getAccCost()
  {
    double cost = 0.0;
    int order = cxs[0].size();

    for (int s = 0; s < times.size(); ++s)
    {
      Eigen::Vector3d um;
      um(0) = 2 * cxs[s][order - 3], um(1) = 2 * cys[s][order - 3], um(2) = 2 * czs[s][order - 3];
      cost += um.squaredNorm() * times[s];
    }

    return cost;
  }

  double getJerk()
  {
    double jerk = 0.0;

    /* get jerk */
    for (int s = 0; s < times.size(); ++s)
    {
      Eigen::VectorXd cxv(cxs[s].size()), cyv(cys[s].size()), czv(czs[s].size());
      /* convert coefficient */
      int order = cxs[s].size();
      for (int j = 0; j < order; ++j)
      {
        cxv(j) = cxs[s][order - 1 - j], cyv(j) = cys[s][order - 1 - j], czv(j) = czs[s][order - 1 - j];
      }
      double ts = times[s];

      /* jerk matrix */
      Eigen::MatrixXd mat_jerk(order, order);
      mat_jerk.setZero();
      for (double i = 3; i < order; i += 1)
        for (double j = 3; j < order; j += 1)
        {
          mat_jerk(i, j) =
              i * (i - 1) * (i - 2) * j * (j - 1) * (j - 2) * pow(ts, i + j - 5) / (i + j - 5);
        }

      jerk += (cxv.transpose() * mat_jerk * cxv)(0, 0);
      jerk += (cyv.transpose() * mat_jerk * cyv)(0, 0);
      jerk += (czv.transpose() * mat_jerk * czv)(0, 0);
    }

    return jerk;
  }

  void getMeanAndMaxVel(double &mean_v, double &max_v)
  {
    int num = 0;
    mean_v = 0.0, max_v = -1.0;
    for (int s = 0; s < times.size(); ++s)
    {
      int order = cxs[s].size();
      Eigen::VectorXd vx(order - 1), vy(order - 1), vz(order - 1);

      /* coef of vel */
      for (int i = 0; i < order - 1; ++i)
      {
        vx(i) = double(i + 1) * cxs[s][order - 2 - i];
        vy(i) = double(i + 1) * cys[s][order - 2 - i];
        vz(i) = double(i + 1) * czs[s][order - 2 - i];
      }
      double ts = times[s];

      double eval_t = 0.0;
      while (eval_t < ts)
      {
        Eigen::VectorXd tv(order - 1);
        for (int i = 0; i < order - 1; ++i)
          tv(i) = pow(ts, i);
        Eigen::Vector3d vel;
        vel(0) = tv.dot(vx), vel(1) = tv.dot(vy), vel(2) = tv.dot(vz);
        double vn = vel.norm();
        mean_v += vn;
        if (vn > max_v)
          max_v = vn;
        ++num;

        eval_t += 0.01;
      }
    }

    mean_v = mean_v / double(num);
  }

  void getMeanAndMaxAcc(double &mean_a, double &max_a)
  {
    int num = 0;
    mean_a = 0.0, max_a = -1.0;
    for (int s = 0; s < times.size(); ++s)
    {
      int order = cxs[s].size();
      Eigen::VectorXd ax(order - 2), ay(order - 2), az(order - 2);

      /* coef of acc */
      for (int i = 0; i < order - 2; ++i)
      {
        ax(i) = double((i + 2) * (i + 1)) * cxs[s][order - 3 - i];
        ay(i) = double((i + 2) * (i + 1)) * cys[s][order - 3 - i];
        az(i) = double((i + 2) * (i + 1)) * czs[s][order - 3 - i];
      }
      double ts = times[s];

      double eval_t = 0.0;
      while (eval_t < ts)
      {
        Eigen::VectorXd tv(order - 2);
        for (int i = 0; i < order - 2; ++i)
          tv(i) = pow(ts, i);
        Eigen::Vector3d acc;
        acc(0) = tv.dot(ax), acc(1) = tv.dot(ay), acc(2) = tv.dot(az);
        double an = acc.norm();
        mean_a += an;
        if (an > max_a)
          max_a = an;
        ++num;

        eval_t += 0.01;
      }
    }

    mean_a = mean_a / double(num);
  }

  static PolynomialTraj minSnapTraj(const Eigen::MatrixXd &Pos, const Eigen::Vector3d &start_vel,
                                    const Eigen::Vector3d &end_vel, const Eigen::Vector3d &start_acc,
                                    const Eigen::Vector3d &end_acc, const Eigen::VectorXd &Time);

  static PolynomialTraj one_seg_traj_gen(const Eigen::Vector3d &start_pt, const Eigen::Vector3d &start_vel, const Eigen::Vector3d &start_acc,
                                         const Eigen::Vector3d &end_pt, const Eigen::Vector3d &end_vel, const Eigen::Vector3d &end_acc,
                                         double t);
};

/* ======= definition of inline function ======= */

inline void PolynomialTraj::init()
{
  num_seg = times.size();
  time_sum = 0.0;
  for (int i = 0; i < times.size(); ++i)
  {
    time_sum += times[i];
  }
}

inline Eigen::Vector3d PolynomialTraj::getPos(double t)
{
  if (t < 0)
    std::cerr << "[PolynomialTraj]: t < 0" << std::endl;

  if (t > time_sum)
    std::cerr << "[PolynomialTraj]: t > time_sum" << std::endl;

  /* detetrmine segment num */
  int idx = 0;
  while (times[idx] + 1e-4 < t)
  {
    t -= times[idx];
    ++idx;
  }

  /* evaluation */
  int order = cxs[idx].size();
  Eigen::VectorXd cx(order), cy(order), cz(order), tv(order);
  for (int i = 0; i < order; ++i)
  {
    cx(i) = cxs[idx][i], cy(i) = cys[idx][i], cz(i) = czs[idx][i];
    tv(order - 1 - i) = std::pow(t, double(i));
  }

  Eigen::Vector3d pt;
  pt(0) = tv.dot(cx), pt(1) = tv.dot(cy), pt(2) = tv.dot(cz);
  return pt;
}

inline Eigen::Vector3d PolynomialTraj::getVel(double t)
{
  /* detetrmine segment num */
  int idx = 0;
  while (times[idx] + 1e-4 < t)
  {
    t -= times[idx];
    ++idx;
  }

  /* evaluation */
  int order = cxs[idx].size();
  Eigen::VectorXd vx(order - 1), vy(order - 1), vz(order - 1);

  /* coef of vel */
  for (int i = 0; i < order - 1; ++i)
  {
    vx(i) = double(i + 1) * cxs[idx][order - 2 - i];
    vy(i) = double(i + 1) * cys[idx][order - 2 - i];
    vz(i) = double(i + 1) * czs[idx][order - 2 - i];
  }
  double ts = t;
  Eigen::VectorXd tv(order - 1);
  for (int i = 0; i < order - 1; ++i)
    tv(i) = pow(ts, i);

  Eigen::Vector3d vel;
  vel(0) = tv.dot(vx), vel(1) = tv.dot(vy), vel(2) = tv.dot(vz);
  return vel;
}

inline Eigen::Vector3d PolynomialTraj::getAcc(double t)
{
  /* detetrmine segment num */
  int idx = 0;
  while (times[idx] + 1e-4 < t)
  {
    t -= times[idx];
    ++idx;
  }

  /* evaluation */
  int order = cxs[idx].size();
  Eigen::VectorXd ax(order - 2), ay(order - 2), az(order - 2);

  /* coef of vel */
  for (int i = 0; i < order - 2; ++i)
  {
    ax(i) = double((i + 2) * (i + 1)) * cxs[idx][order - 3 - i];
    ay(i) = double((i + 2) * (i + 1)) * cys[idx][order - 3 - i];
    az(i) = double((i + 2) * (i + 1)) * czs[idx][order - 3 - i];
  }
  double ts = t;
  Eigen::VectorXd tv(order - 2);
  for (int i = 0; i < order - 2; ++i)
    tv(i) = pow(ts, i);

  Eigen::Vector3d acc;
  acc(0) = tv.dot(ax), acc(1) = tv.dot(ay), acc(2) = tv.dot(az);
  return acc;
}

#endif